from pathlib import Path
from audios.audio_asr import AudioAsr
import pyarrow.parquet as pq
import json
import re


class MalayDataset:
    """Run speech recognition over small samples of Malay speech datasets.

    Each ``datasetN`` method collects records of the form
    ``{'wav_file': ..., 'source_text': ...}`` from one dataset stored under
    ``data_dir``, passes them to ``self.asr.recognize`` and prints the value
    it returns as indented JSON.
    """

    # Language name (Chinese for "Malay") handed to the recognizer as ``zh_lang``.
    zh_lang = '马来语'
    # Root directory that holds the downloaded datasets.
    data_dir = '/root/autodl-tmp/datasets'

    # A tag of the form ``<|...|>``.
    _TAG_RE = re.compile(r'<\|([^\<\>]*)\|>')
    # The first run of plain text enclosed between two tags (``...>text<...``).
    _TEXT_RE = re.compile(r'>([^\<\>]+)<')

    def __init__(self):
        self.asr = AudioAsr()

    def read_text(self, text_file):
        """Return the full content of a UTF-8 text file.

        Args:
            text_file: Path of the file to read.

        Returns:
            The file content as a single string.
        """
        # Use a context manager so the handle is always closed, and read the
        # content as-is: readlines() keeps each line ending, so joining the
        # lines with '\n' again would double every line break.
        with open(text_file, 'r', encoding='utf-8') as fs:
            return fs.read()

    def dataset1(self, top_k: int = 10):
        """Recognize a sample of the Nexdata conversational speech dataset.

        Repository: Nexdata/Malay_Conversational_Speech_Data_by_Mobile_Phone
        Data: 7 records

        Every ``*.wav`` file is paired with the ``.txt`` file that has the
        same stem in the same directory; that file supplies the source text.

        Args:
            top_k: Number of audio files after which collection stops.
        """
        dataset_dir = Path(self.data_dir) / 'Malay_Conversational_Speech_Data_by_Mobile_Phone'
        records = []
        # glob() order depends on the filesystem; sorting keeps the selected
        # sample stable between runs.
        for audio_file in sorted(dataset_dir.glob('*.wav')):
            if len(records) == top_k:
                break
            text_file = audio_file.parent / f'{audio_file.stem}.txt'
            records.append({
                'wav_file': audio_file.as_posix(),
                'source_text': self.read_text(text_file),
            })

        self._recognize_and_print(records)

    def dataset2(self, top_k: int = 10):
        """Recognize a sample of the iban-whisper-format dataset.

        Repository: malaysia-ai/iban-whisper-format
        Data: 2200 records

        Reads ``iban-dataset.json`` and, for each of the first ``top_k``
        items, uses ``item['filename']`` (relative to the dataset directory)
        as the audio path and the text extracted from ``item['Y']`` as the
        source text.

        Args:
            top_k: Number of leading items of the JSON list to process.
        """
        dataset_dir = Path(self.data_dir) / 'iban-whisper-format'
        text_file = dataset_dir / 'iban-dataset.json'

        with open(text_file, 'r', encoding='utf-8') as fs:
            datas = json.load(fs)

        records = [
            {
                'wav_file': (dataset_dir / item['filename']).as_posix(),
                'source_text': self._extract_text(item['Y']),
            }
            for item in datas[0: top_k]
        ]

        self._recognize_and_print(records)

    def dataset3(self, top_k: int = 10):
        """Placeholder for the Malaysian-STT-Whisper dataset.

        Repository: mesolitica/Malaysian-STT-Whisper
        Data: 2262 records
        github: https://github.com/mesolitica/malaysian-dataset/tree/master/speech-to-text-semisupervised/distilled-malaysian-whisper
        https://malaysian-dataset.readthedocs.io/en/latest/

        Not implemented yet: the method does nothing and ``top_k`` is unused.
        """
        # Exploratory snippet for inspecting the parquet shard:
        # file = '/root/autodl-tmp/datasets/Malaysian-STT-Whisper/data/malaysian_context-00000-of-00001.parquet'
        # parquet_file = pq.ParquetFile(file)
        # data = parquet_file.read().to_pandas()
        # print(data.head())

    @classmethod
    def _extract_text(cls, tagged):
        """Return the first text segment enclosed between tags in *tagged*.

        When no such segment exists (for example the value carries no tags
        at all), the input with every ``<|...|>`` tag removed is returned
        instead of failing on a missing regex match.
        """
        match = cls._TEXT_RE.search(tagged)
        if match is not None:
            return match.group(1)
        return cls._TAG_RE.sub('', tagged)

    def _recognize_and_print(self, records):
        """Run the recognizer over *records* and print its result as JSON."""
        results = self.asr.recognize(records, zh_lang=self.zh_lang)
        print(json.dumps(results, ensure_ascii=False, indent=4))


if __name__ == '__main__':
    # The dataset runs are currently switched off:
    # malay = MalayDataset()
    # malay.dataset1()
    # malay.dataset2()

    # Load one Malaysian-STT-Whisper parquet shard and print its first rows.
    shard_path = '/root/autodl-tmp/datasets/Malaysian-STT-Whisper/data/malaysian_context-00000-of-00001.parquet'
    frame = pq.ParquetFile(shard_path).read().to_pandas()
    print(frame.head())